--- ../../plenoxels/opt/util/util.py	2024-02-13 12:06:07.930166019 -0500
+++ ../google/plenoxels/opt/util/util.py	2024-02-13 10:46:24.452645156 -0500
@@ -1,3 +1,4 @@
+import cv2
 import torch
 import torch.cuda
 import torch.nn.functional as F
@@ -7,7 +8,7 @@
 import cv2
 from scipy.spatial.transform import Rotation
 from scipy.interpolate import CubicSpline
-from matplotlib import pyplot as plt
+# from matplotlib import pyplot as plt
 from warnings import warn
 
 
@@ -118,8 +119,8 @@
     :param gray: np.ndarray, (H, W) or (H, W, 1) unscaled
     :return: (H, W, 3) float32 in [0, 1]
     """
-    colored = plt.cm.viridis(plt.Normalize()(gray.squeeze()))[..., :-1]
-    return colored.astype(np.float32)
+    #colored = plt.cm.viridis(plt.Normalize()(gray.squeeze()))[..., :-1]
+    return np.array(gray).astype(np.float32)
 
 
 def save_img(img: np.ndarray, path: str):
@@ -476,3 +477,62 @@
     if offset is not None:
         c2w[:3, 3] += offset
     return c2w
+
+
+def convert_rays_plenoxels_to_plenoctre(rays, scene_scale):
+    """Deliberately disabled ray conversion; every call raises.
+
+    Only the camera convention differs between the two code bases; rays are
+    expressed in the same 3D space, so non-NDC rays need no conversion. The
+    unreachable draft that used to follow the raise (right-multiplying
+    origins, dirs and viewdirs by diag(1, -1, -1)) has been removed.
+
+    :param rays: unused
+    :param scene_scale: unused
+    :return: never returns
+    :raises Exception: always
+    """
+    raise Exception("The camera convention changes, but the rays in 3D space are the same, So no need to transform rays from one to the other if they are not in NDC coordinates!")
+    
+def convert_c2w_plenoxel_to_pleonctree(c2w, scene_scale):
+    """Flip the camera y/z axes of c2w pose(s) and rescale the translation.
+
+    :param c2w: np.ndarray or torch.Tensor, (..., 4, 4) or (..., 3, 4)
+    :param scene_scale: divisor applied to the translation column
+    :return: converted copy; np.ndarray (float32) or torch.Tensor like c2w
+    """
+    is_np = isinstance(c2w, np.ndarray)
+    if is_np:
+        c2w = torch.FloatTensor(c2w)
+    c2w_converted = c2w.clone()  # never modify the caller's pose in place
+    c2w_converted[..., :3, 3] /= scene_scale  # "..." keeps batched poses correct
+    # OpenGL -> OpenCV; .to(tensor) matches device AND dtype (float64 poses work)
+    cam_trans = torch.diag(torch.tensor([1, -1, -1, 1], dtype=torch.float32))
+    c2w_converted = c2w_converted @ cam_trans.to(c2w_converted)
+    return np.array(c2w_converted.cpu()) if is_np else c2w_converted
+    
+
+
+def cv2_resize(image, target_height_width, interpolation=cv2.INTER_NEAREST):
+  """Resize an (H, W) or channel-first (C, H, W) array to (height, width); cv2 drops a 1-channel axis, so it is restored."""
+  hwc = image if image.ndim == 2 else image.transpose((1, 2, 0))
+  out = cv2.resize(hwc, target_height_width[::-1], interpolation=interpolation)  # dsize is (width, height)
+  return out if image.ndim == 2 else out.reshape(out.shape[:2] + (-1,)).transpose((2, 0, 1))
+
+
+def construct_mask_mae(masking_ratios, img_size, patch_size, random_seed):
+  """Build a seeded patch-wise mask of shape img_size: 0 = masked, 1 = active.
+
+  masking_ratios is a scalar ratio or a (low, high) range sampled uniformly;
+  the global NumPy RNG state is restored afterwards, even on error.
+  """
+  assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, "img_size should be multiple of patch_size!"
+  low, high = (masking_ratios,) * 2 if np.ndim(masking_ratios) == 0 else masking_ratios[:2]
+  previous_state = np.random.get_state()
+  try:
+    np.random.seed(random_seed)
+    mask = np.random.uniform(0, 1, size=(img_size[0] // patch_size, img_size[1] // patch_size))
+    masking_ratio = low if low == high else np.random.uniform(low, high)
+  finally:
+    np.random.set_state(previous_state)
+  return 1 - cv2_resize((mask < masking_ratio).astype('float32'), img_size, interpolation=cv2.INTER_NEAREST)
